from torch import nn
import torch


class LCNN(nn.Module):
    """Small convolutional network for single-channel 2-D inputs.

    Three unpadded, stride-1 convolutions (3x3 -> 3x3 -> 2x2, each followed by
    ReLU) shrink every spatial dimension by 5 in total. The resulting
    ``64 x (H - 5) x (W - 5)`` feature map is flattened and passed through two
    fully-connected layers and a sigmoid.

    Args:
        in_height: Height ``H`` of the input. Defaults to 69, the size the
            original hard-coded ``nn.Linear(64 * 64, 32)`` was built for.
        in_width: Width ``W`` of the input. Defaults to 6.
        num_classes: Number of output units. Defaults to 2.

    Input:  batched tensor of shape ``(N, 1, in_height, in_width)``.
    Output: tensor of shape ``(N, num_classes)`` with values in ``[0, 1]``.

    Raises:
        ValueError: if ``in_height`` or ``in_width`` is smaller than 6, since
            the convolutions would then produce an empty feature map.
    """

    # Total spatial shrinkage of the unpadded stride-1 convs: (3-1) + (3-1) + (2-1).
    _SPATIAL_SHRINK = 5
    # Channel count produced by the last convolution.
    _CONV_OUT_CHANNELS = 64

    def __init__(self, in_height=69, in_width=6, num_classes=2):
        super().__init__()
        out_h = in_height - self._SPATIAL_SHRINK
        out_w = in_width - self._SPATIAL_SHRINK
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f"input spatial size ({in_height}, {in_width}) is too small; "
                f"both dimensions must be >= {self._SPATIAL_SHRINK + 1}"
            )
        # Derived instead of hard-coded so other input sizes work; the
        # defaults reproduce the original 64 * 64 = 4096 features.
        flat_features = self._CONV_OUT_CHANNELS * out_h * out_w

        # Layer order (and therefore Sequential indices / state_dict keys) is
        # kept identical to the original so existing checkpoints still load.
        self.model = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 32, kernel_size=(3, 3), stride=(1, 1), padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=(2, 2), stride=(1, 1), padding=0),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            # NOTE(review): there is no non-linearity between the two Linear
            # layers, so together they act as a single linear map. Left as-is
            # to keep checkpoint compatibility — confirm this is intended.
            nn.Linear(flat_features, 32),
            nn.Linear(32, num_classes),
            # NOTE(review): a Sigmoid on a multi-unit head implies independent
            # per-unit probabilities; if training uses CrossEntropyLoss, raw
            # logits are expected instead — confirm against the training code.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` of shape ``(N, 1, H, W)`` through the network; return ``(N, num_classes)``."""
        return self.model(x)


if __name__ == "__main__":
    # Smoke test: push a dummy batch through the model and print the output shape.
    lcnn = LCNN()
    # Renamed from `input` to avoid shadowing the Python builtin.
    dummy_batch = torch.ones((10, 1, 69, 6))
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        output = lcnn(dummy_batch)
    print(output.shape)  # expected: torch.Size([10, 2])

